## Divert console output to the appendix log file; the diversion stays active
## until the matching sink() call that follows the save() step.
sink("./generated/appendix_2/full_nonlinear_samp_correct.txt")

## Load the shared simulation setup and the data-generating code.
source("./04_simulation_00_setup.R")
source("./04_simulation_01_DGP_full_sims.R")

## Simulation constants.
nboots <- 1  # passed as `sims` to every tpate() call below
K <- 10      # forwarded to dgp_exp_pop() as `K`

## Accumulates the stacked results of every simulated scenario.
full_sims <- NULL

## Number of forked workers for mclapply(). detectCores() - 1 is 0 on a
## single-core machine and NA when the core count cannot be determined; both
## would make mclapply() fail, so fall back to a single worker.
n_workers <- max(1L, detectCores() - 1L, na.rm = TRUE)

for (correct in c("bart")) {
  for (total_size in c(1000, 2000, 8000)) {
    for (sampling_model in c("Yes")) {
      ## NOTE(review): this seeds the parent process only. Whether the forked
      ## workers draw reproducible random numbers depends on the active
      ## RNGkind() and on whether dgp_exp_pop() seeds itself from `sim_index`
      ## -- confirm before relying on exact replication.
      set.seed(6725498)

      sims_list <- mclapply(seq_len(nsim), function(i) {

        ## Draw data
        dgp <- dgp_exp_pop(total_size = total_size,
                           K = K,
                           K_W = 5,
                           sim_index = i,
                           outcome_type = correct)

        exp_data <- as.data.frame(dgp$exp_data)
        pop_data <- as.data.frame(dgp$pop_data)

        ## Outcome model
        formula_outcome <- Y ~ treatment + X_1 + X_2 + X_3 + X_4 + X_5

        ## Sampling model
        if (sampling_model == "Yes") {
          ## Correct weight model
          formula_weights <- S ~ X_1 + X_2 + X_3 + X_4 + X_5
        } else {
          ## Incorrect weight model
          formula_weights <- S ~ X_1 + X_2 + X_3
        }

        ## Summary columns kept from each tpate() fit.
        est_cols <- c("est", "se", "ci_lower", "ci_upper")

        ## Fit one estimator with tpate() and return its labelled summary.
        ##   estimator, type : labels stored in the output columns
        ##   est_type        : forwarded to tpate(est_type = )
        ##   ...             : estimator-specific tpate() arguments
        ##                     (here only `weights_type`)
        ##   boot            : forwarded to tpate(boot = )
        ##   compute_sate    : when TRUE the `sate` element of the fit is
        ##                     summarised, otherwise the `tpate` element
        fit_estimator <- function(estimator, type, est_type, ...,
                                  boot = FALSE, compute_sate = FALSE) {
          fit <- tpate(formula_outcome = formula_outcome,
                       formula_weights = formula_weights,
                       exp_data = exp_data, pop_data = pop_data,
                       treatment = "treatment",
                       id_cluster = exp_data$cluster,
                       est_type = est_type,
                       sims = nboots,
                       boot = boot,
                       numCores = 1,
                       compute_sate = compute_sate,
                       ...)
          estimate <- if (compute_sate) fit$sate else fit$tpate
          data.frame(estimator = estimator, type = type,
                     data.frame(estimate[est_cols]))
        }

        ## The fits run in a fixed order so that any random numbers they
        ## consume are drawn in the same sequence on every run.

        ## SATE estimate
        sate_est <- fit_estimator("SATE", "SATE",
                                  est_type = "ipw", weights_type = "logit",
                                  compute_sate = TRUE)

        ## Weighting: IPW
        ipw_logit <- fit_estimator("IPW-logit", "weighting",
                                   est_type = "ipw", weights_type = "logit")

        ## wLS: IPW
        wls_logit <- fit_estimator("wLS-logit", "weighting",
                                   est_type = "wls", weights_type = "logit")

        ## Outcome-based: OLS (the only fit requested with boot = TRUE)
        out_ols <- fit_estimator("OLS-proj", "outcome",
                                 est_type = "outcome-ols", boot = TRUE)

        ## Outcome-based: BART
        out_bart <- fit_estimator("BART-proj", "outcome",
                                  est_type = "outcome-bart")

        ## DR Estimator -- Aug. OLS
        dr_logit_ols <- fit_estimator("DR-OLS-logit", "doubly robust",
                                      est_type = "dr-ols",
                                      weights_type = "logit")

        ## DR Estimator -- Aug. BART
        dr_logit_bart <- fit_estimator("DR-BART-logit", "doubly robust",
                                       est_type = "dr-bart",
                                       weights_type = "logit")

        ## Truth: mean difference of the Y_1 and Y_0 columns of the
        ## population data.
        PATE <- data.frame(estimator = "PATE", type = "PATE",
                           est = mean(pop_data$Y_1 - pop_data$Y_0))

        data.frame(sim = i, n = total_size / 2,
                   bind_rows(sate_est,
                             ipw_logit,
                             wls_logit,
                             out_ols,
                             out_bart,
                             dr_logit_ols,
                             dr_logit_bart,
                             PATE))
      }, mc.cores = n_workers)

      ## mclapply() hands back "try-error" objects for iterations that threw,
      ## and NULL for iterations whose worker process died. Report both
      ## explicitly instead of leaving them to bind_rows().
      failed <- vapply(sims_list, inherits, logical(1), what = "try-error")
      if (any(failed)) {
        stop(sum(failed), " simulation(s) failed for total_size = ",
             total_size, "; first error: ",
             as.character(sims_list[[which(failed)[1]]]),
             call. = FALSE)
      }
      dropped <- vapply(sims_list, is.null, logical(1))
      if (any(dropped)) {
        warning(sum(dropped), " simulation(s) returned no result for ",
                "total_size = ", total_size,
                " and are missing from the output",
                call. = FALSE)
      }

      ## Stack the per-simulation results and tag them with the scenario.
      sims <- bind_rows(sims_list) %>%
        mutate(correct_model = correct,
               correct_sample = sampling_model)

      full_sims <- bind_rows(full_sims, sims)

    }
  }
}

## Save the stacked results under a scenario-specific object name, then close
## the log diversion opened at the top of the script.
full_sims_nonlinear_samp_correct <- full_sims

save(full_sims_nonlinear_samp_correct,
     file = "./generated/appendix_2/full_sims_nonlinear_samp_correct.RData")
sink()

## Write the elapsed run time to its own file.
## NOTE(review): start_time is not defined in this file -- presumably set by
## one of the sourced setup scripts; confirm.
sink('./generated/appendix_2/full_sims_time_02.txt')
end_time <- Sys.time()
## Print explicitly: a bare top-level expression is not auto-printed when this
## script is run through source(), which would leave the timing file empty.
print(end_time - start_time)
sink()